# http://proceedings.mlr.press/v101/huang19a/huang19a.pdf
# https://www.researchgate.net/publication/220875351_Generative_Models_for_Labeling_Multi-object_Configurations_in_Images
# https://www.tensorflow.org/datasets/catalog/open_images_v4
# Auto-Encoding Progressive Generative Adversarial Networks For 3D Multi Object Scenes
# TODO
# for dataset KITTI (as an autonomous-driving case study) - for the built model
# 1. report model loss for validation dataset - Done
# 2. visualize reconstructed images - Done
# 3. Grid search (K, cov type) for gaussian mixture log p comparison (or Bayesian parameter optimization) - SKIP
# reason: need to focus on core idea - GM is good other than G in Autonomous driving on a simplified case
# 4. read about inf Gaussian mixture https://www.seas.harvard.edu/courses/cs281/papers/rasmussen-1999a.pdf
# Datasets to experiment with
%config Completer.use_jedi = False
# Standard library
import logging
import os

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds
from ipywidgets import IntProgress
from sklearn.mixture import GaussianMixture
from tensorflow.keras import layers, losses
# `tqdm.tqdm_notebook` is deprecated (emitted a TqdmDeprecationWarning);
# `tqdm.notebook.tqdm` is the supported notebook progress bar.
from tqdm.notebook import tqdm
# Reproducibility and training configuration.
seed = 1
# Use the `seed` constant consistently (previously the literal 1 was
# duplicated, so changing `seed` would silently have no effect).
np.random.seed(seed)
tf.random.set_seed(seed)
batch_size = 32
epochs = 10
dataset_name = 'kitti'
# Load train/test/validation splits for the selected dataset.
if dataset_name == 'bdd100k':
    # bdd100k lives on local disk; each split is a separate directory.
    _from_dir = tf.keras.preprocessing.image_dataset_from_directory
    train_ds = _from_dir(directory='../data/bdd100k/images/10k/train1/', batch_size=batch_size)
    test_ds = _from_dir(directory='../data/bdd100k/images/10k/test1/', batch_size=batch_size)
    validation_ds = _from_dir(directory='../data/bdd100k/images/10k/val1/', batch_size=batch_size)
elif dataset_name in ['flic', 'fashion_mnist', 'mnist', 'kitti']:
    # These TFDS datasets ship only train/test splits; reuse test for validation.
    train_ds, test_ds = tfds.load(
        name=dataset_name, split=['train', 'test'],
        as_supervised=False, download=True)
    validation_ds = test_ds
elif dataset_name in ['wider_face']:
    # wider_face has a dedicated validation split.
    train_ds, test_ds, validation_ds = tfds.load(
        name=dataset_name, split=['train', 'test', 'validation'],
        as_supervised=False, download=True)
else:
    raise ValueError(f'Unhandled dataset {dataset_name}')
# Collect per-example image dimensions into a DataFrame so a common model
# input size can be chosen below.
if dataset_name == 'bdd100k':
    # Directory datasets yield (images, labels) batches -> shape is (batch, h, w, d).
    shapes = [x[0].get_shape().as_list() for x in train_ds]
    dims_df = pd.DataFrame.from_records(data=shapes, columns=['batch', 'height', 'width', 'depth'])
else:
    # TFDS examples are dicts with an unbatched 'image' entry -> (h, w, d).
    shapes = [x['image'].get_shape().as_list() for x in train_ds]
    dims_df = pd.DataFrame.from_records(data=shapes, columns=['height', 'width', 'depth'])
dims_df.describe()
| height | width | depth | |
|---|---|---|---|
| count | 6347.000000 | 6347.000000 | 6347.0 |
| mean | 374.481960 | 1240.112494 | 3.0 |
| std | 1.447946 | 5.220926 | 0.0 |
| min | 370.000000 | 1224.000000 | 3.0 |
| 25% | 375.000000 | 1242.000000 | 3.0 |
| 50% | 375.000000 | 1242.000000 | 3.0 |
| 75% | 375.000000 | 1242.000000 | 3.0 |
| max | 375.000000 | 1242.000000 | 3.0 |
# Choose a square model input size: the largest multiple of m that fits the
# smallest image on both axes.
m = 20
height = (int(min(dims_df['height'])) // m) * m
width = (int(min(dims_df['width'])) // m) * m
# Alternative: snap down to a power of two instead of a multiple of m.
# height = 2**(int(np.log2(min(dims_df['height']))))
# width = 2**(int(np.log2(min(dims_df['width']))))
depth = min(dims_df['depth'])
# Force a square input by taking the smaller of the two sides.
side = min(height, width)
height, width = side, side
height, width, depth
(360, 360, 3)
# Peek at a few raw examples to see which features TFDS provides.
for sample in train_ds.take(3):
    print(sample.keys())
dict_keys(['image', 'image/file_name', 'objects']) dict_keys(['image', 'image/file_name', 'objects']) dict_keys(['image', 'image/file_name', 'objects'])
# Normalise pixel values to [0, 1]; for TFDS datasets additionally resize to
# the common (height, width) and batch, dropping the ragged final batch.
if dataset_name == 'bdd100k':
    # Already batched by image_dataset_from_directory; drop labels, rescale.
    train_ds = train_ds.map(lambda img, lbl: img / 255.)
    test_ds = test_ds.map(lambda img, lbl: img / 255.)
    validation_ds = validation_ds.map(lambda img, lbl: img / 255.)
else:
    def _prep(example):
        # Cast -> rescale -> resize: interpolation happens on [0, 1] floats.
        return tf.image.resize(
            images=tf.cast(example['image'], dtype=tf.float32) / 255.,
            size=[height, width])
    train_ds = train_ds.map(_prep).batch(batch_size, drop_remainder=True)
    test_ds = test_ds.map(_prep).batch(batch_size, drop_remainder=True)
    validation_ds = validation_ds.map(_prep).batch(batch_size, drop_remainder=True)
# Autoencoder targets equal inputs: zip each dataset with itself -> (x, x).
train_ds_double_zipped = tf.data.Dataset.zip(datasets=(train_ds, train_ds))
test_ds_double_zipped = tf.data.Dataset.zip(datasets=(test_ds, test_ds))
validation_ds_double_zipped = tf.data.Dataset.zip(datasets=(validation_ds, validation_ds))
latent_dim = 128
class CAE(tf.keras.Model):
    """Convolutional autoencoder.

    NOTE(review): the original docstring said "variational autoencoder", but
    this model is a plain deterministic AE -- the encoder emits a point
    estimate `z`, not distribution parameters, and there is no KL term.

    Args:
        latent_dim: size of the bottleneck code produced by the encoder.
    """

    def __init__(self, latent_dim):
        super(CAE, self).__init__()
        self.latent_dim = latent_dim
        self.logger = logging.getLogger('CAE')
        # Encoder: two stride-2 convs (spatial size roughly /4), flatten,
        # then a linear projection to the latent code. The final Dense layer
        # has no activation so the code is unbounded.
        # NOTE(review): relies on module-level height/width/depth globals.
        self.encoder = tf.keras.Sequential(name='encoder', layers=[
            tf.keras.layers.InputLayer(input_shape=(height, width, depth)),
            tf.keras.layers.Conv2D(
                filters=32, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Conv2D(
                filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Flatten(),
            # No activation
            tf.keras.layers.Dense(latent_dim),
        ])
        # Decoder: project back to an (h/4, w/4, 32) grid, then two stride-2
        # transposed convs restore full resolution. The final layer has no
        # activation, so reconstructions are unclipped floats.
        self.decoder = tf.keras.Sequential(name='decoder', layers=[
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(
                units=int(height / 4) * int(width / 4) * 32, activation=tf.nn.relu),
            tf.keras.layers.Reshape(target_shape=(int(height / 4), int(width / 4), 32)),
            tf.keras.layers.Conv2DTranspose(
                filters=64, kernel_size=3, strides=2, padding='same',
                activation='relu'),
            tf.keras.layers.Conv2DTranspose(
                filters=32, kernel_size=3, strides=2, padding='same',
                activation='relu'),
            # No activation
            tf.keras.layers.Conv2DTranspose(
                filters=depth, kernel_size=3, strides=1, padding='same'),
        ])
        self.encoder.summary()
        self.decoder.summary()

    def call(self, x):
        """Encode then decode `x`; returns the reconstruction."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
# Build the autoencoder and train it with pixel-wise MSE reconstruction loss.
cae = CAE(latent_dim)
cae.compile(optimizer='adam', loss=losses.MeanSquaredError())
Model: "encoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 179, 179, 32) 896 _________________________________________________________________ conv2d_1 (Conv2D) (None, 89, 89, 64) 18496 _________________________________________________________________ flatten (Flatten) (None, 506944) 0 _________________________________________________________________ dense (Dense) (None, 128) 64888960 ================================================================= Total params: 64,908,352 Trainable params: 64,908,352 Non-trainable params: 0 _________________________________________________________________ Model: "decoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 259200) 33436800 _________________________________________________________________ reshape (Reshape) (None, 90, 90, 32) 0 _________________________________________________________________ conv2d_transpose (Conv2DTran (None, 180, 180, 64) 18496 _________________________________________________________________ conv2d_transpose_1 (Conv2DTr (None, 360, 360, 32) 18464 _________________________________________________________________ conv2d_transpose_2 (Conv2DTr (None, 360, 360, 3) 867 ================================================================= Total params: 33,474,627 Trainable params: 33,474,627 Non-trainable params: 0 _________________________________________________________________
# Model artifacts are keyed by dataset, latent size and input dimensions so
# runs with different configurations don't overwrite each other.
model_file_path = f'./models/cae_dataset_{dataset_name}_z_dim_{latent_dim}_data_dim_{height}x{width}x{depth}'
print(f'model path = {model_file_path}')
model path = ./models/cae_dataset_kitti_z_dim_128_data_dim_360x360x3
# Load a previously saved model if one exists; otherwise train from scratch.
if os.path.exists(model_file_path):
    print('loading saved model')
    cae = tf.keras.models.load_model(filepath=model_file_path)
else:
    print('building model')
    # use checkpoints to save model fitting progress
    # https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint
    checkpoint_filepath = './checkpoint'
    # BUG FIX: val_loss must be MINIMISED. The original mode='max' made the
    # callback keep the checkpoint with the WORST validation loss.
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_loss',
        mode='min',
        save_best_only=True)
    # Model weights are saved at the end of every epoch, if it's the best seen
    # so far.
    cae.fit(x=train_ds_double_zipped, validation_data=test_ds_double_zipped,
            epochs=epochs, callbacks=[model_checkpoint_callback],
            use_multiprocessing=True)
    # The model weights (that are considered the best) are loaded into the model.
    cae.load_weights(checkpoint_filepath)
    print('saving model')
    cae.save(filepath=model_file_path)
building model Epoch 1/10 198/198 [==============================] - 459s 2s/step - loss: 0.1017 - val_loss: 0.0429 Epoch 2/10 198/198 [==============================] - 435s 2s/step - loss: 0.0446 - val_loss: 0.0341 Epoch 3/10 198/198 [==============================] - 435s 2s/step - loss: 0.0355 - val_loss: 0.0308 Epoch 4/10 198/198 [==============================] - 961s 5s/step - loss: 0.0312 - val_loss: 0.0289 Epoch 5/10 198/198 [==============================] - 458s 2s/step - loss: 0.0282 - val_loss: 0.0284 Epoch 6/10 198/198 [==============================] - 449s 2s/step - loss: 0.0263 - val_loss: 0.0282 Epoch 7/10 198/198 [==============================] - 438s 2s/step - loss: 0.0253 - val_loss: 0.0267 Epoch 8/10 198/198 [==============================] - 884s 4s/step - loss: 0.0238 - val_loss: 0.0264 Epoch 9/10 198/198 [==============================] - 455s 2s/step - loss: 0.0233 - val_loss: 0.0269 Epoch 10/10 198/198 [==============================] - 439s 2s/step - loss: 0.0231 - val_loss: 0.0275 saving model INFO:tensorflow:Assets written to: ./models/cae_dataset_kitti_z_dim_128_data_dim_360x360x3/assets
INFO:tensorflow:Assets written to: ./models/cae_dataset_kitti_z_dim_128_data_dim_360x360x3/assets
# Create the validation dataset tensor by concatenating all batches directly.
# This replaces the previous reduce() with a zero-valued dummy initial state
# (which then had to be sliced off) -- same result, no wasted dummy batch.
validation_ds_tensor = tf.concat([batch for batch in validation_ds], axis=0)
# Calculate reconstruction loss; values are comparable across datasets because
# pixels were scaled to [0, 1].
y_predicted = cae.predict(validation_ds)
cae_loss = cae.loss(y_pred=y_predicted, y_true=validation_ds_tensor).numpy()
print(f'CAE loss for dataset {dataset_name} = {np.round(cae_loss,4)}')
CAE loss for dataset kitti = 0.042899999767541885
# Plot originals next to their reconstructions for two validation batches.
for batch in validation_ds.take(2):
    z = cae.encoder(batch).numpy()
    decoded_imgs = cae.decoder(z).numpy()
    for i in range(batch.shape[0]):
        fig, (ax1, ax2) = plt.subplots(1, 2)
        ax1.imshow(batch[i])
        # The decoder has no final activation, so outputs can fall outside
        # [0, 1]; clip explicitly instead of letting imshow warn and clip.
        ax2.imshow(np.clip(decoded_imgs[i], 0., 1.))
        plt.show()
        # Close each figure once rendered so matplotlib doesn't accumulate
        # >20 open figures (previously triggered a memory warning).
        plt.close(fig)
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). 
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/ipykernel_launcher.py:8: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). 
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). 
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). 
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Encode the test split into latent vectors for density modelling.
z_tensor = None
cardinality = test_ds.cardinality()
inf_or_unknown_cardinality = ((cardinality == tf.data.INFINITE_CARDINALITY)
                              or (cardinality == tf.data.UNKNOWN_CARDINALITY)).numpy()
# Fall back to a fixed cap when the dataset size is unknown or infinite.
batches = cardinality.numpy() if not inf_or_unknown_cardinality else 500
# Collect per-batch codes in a list and concatenate ONCE at the end; the
# previous per-iteration tf.concat re-copied the whole accumulator each step
# (quadratic in the number of batches).
z_parts = []
with tqdm(total=batches) as pbar:
    for batch in test_ds.take(batches):
        z_parts.append(cae.encoder(batch).numpy())
        pbar.update(1)
if z_parts:
    z_tensor = tf.convert_to_tensor(np.concatenate(z_parts, axis=0))
z_tensor.shape
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/ipykernel_launcher.py:8: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0 Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
TensorShape([704, 128])
# Hold out 20% of the latent codes for scoring the density models.
z_np = z_tensor.numpy()
n_z_train = int(0.8 * z_np.shape[0])
z_train, z_test = z_np[:n_z_train], z_np[n_z_train:]
# Compare k-component Gaussian mixtures against a single-Gaussian baseline on
# held-out latent codes, using DIAGONAL covariances.
random_state = 1
reg_covar = 0.1  # regularisation added to covariance diagonals for stability
cov_type = 'diag'
print(f"""For dataset "{dataset_name}" calculating relative difference of log likelihood """)
print(f'Latent_dim = {latent_dim}, Gaussian Mixture covariance type = {cov_type} and reg_covar = {reg_covar} ')
print('############################ ')
# Baseline: a single Gaussian. Uses the shared random_state constant
# (previously hard-coded to the literal 1).
g_fit = GaussianMixture(n_components=1, covariance_type=cov_type,
                        random_state=random_state, reg_covar=reg_covar).fit(z_train)
logp_g = g_fit.score(X=z_test)
for k in [10, 20, 50, 70, 80, 100, 200]:
    try:
        gm_fit = GaussianMixture(n_components=k, covariance_type=cov_type,
                                 random_state=random_state,
                                 reg_covar=reg_covar).fit(z_train)
        logp_gm = gm_fit.score(X=z_test)
        # Positive value => the mixture explains held-out codes better.
        rel_diff_logps = (logp_gm - logp_g) / np.abs(logp_g)
        print(f'logp Gaussian Mixture with k = {k} = {logp_gm} ')
        print(f'logp single Gaussian = {logp_g} ')
        print(f'At k = {k} , rel_diff for logps = {rel_diff_logps} ')
        print('############## ')
    except Exception as e:
        # EM can fail to fit (e.g. too many components for the data);
        # report and continue with the next k.
        print(f'Caught exception {e} ')
For Dateset "kitti" Calculating relative difference of log likelihood Latent_dim = 128, Gaussiam Mixture covariance type = diag and reg_covar = 0.1 ############################ logp Gaussin Mixture with k = 10 = -170.6917955997237 logp Gaussian Diagonal = -262.341110386977 At k = 10 , rel_diff for logps = 0.3493517072183701 ############## logp Gaussin Mixture with k = 20 = -137.54189665032573 logp Gaussian Diagonal = -262.341110386977 At k = 20 , rel_diff for logps = 0.475713522568236 ############## logp Gaussin Mixture with k = 50 = -113.42024058446295 logp Gaussian Diagonal = -262.341110386977 At k = 50 , rel_diff for logps = 0.567661201032664 ############## logp Gaussin Mixture with k = 70 = -103.88851239114166 logp Gaussian Diagonal = -262.341110386977 At k = 70 , rel_diff for logps = 0.6039945388738477 ############## logp Gaussin Mixture with k = 80 = -101.45960509411145 logp Gaussian Diagonal = -262.341110386977 At k = 80 , rel_diff for logps = 0.6132531232163754 ############## logp Gaussin Mixture with k = 100 = -100.49451496007462 logp Gaussian Diagonal = -262.341110386977 At k = 100 , rel_diff for logps = 0.6169318837911599 ############## logp Gaussin Mixture with k = 200 = -97.06446524302851 logp Gaussian Diagonal = -262.341110386977 At k = 200 , rel_diff for logps = 0.6300066539329288 ##############
# Same comparison as above but with FULL covariance matrices.
random_state = 1
reg_covar = 0.1  # regularisation added to covariance diagonals for stability
cov_type = 'full'
print(f"""For dataset "{dataset_name}" calculating relative difference of log likelihood """)
print(f'Latent_dim = {latent_dim}, Gaussian Mixture covariance type = {cov_type} and reg_covar = {reg_covar} ')
print('############################ ')
# Baseline: a single full-covariance Gaussian. Uses the shared random_state
# constant (previously hard-coded to the literal 1).
g_fit = GaussianMixture(n_components=1, covariance_type=cov_type,
                        random_state=random_state, reg_covar=reg_covar).fit(z_train)
logp_g = g_fit.score(X=z_test)
for k in [10, 20, 50, 70, 80, 100, 200]:
    try:
        gm_fit = GaussianMixture(n_components=k, covariance_type=cov_type,
                                 random_state=random_state,
                                 reg_covar=reg_covar).fit(z_train)
        logp_gm = gm_fit.score(X=z_test)
        # Positive value => the mixture explains held-out codes better.
        rel_diff_logps = (logp_gm - logp_g) / np.abs(logp_g)
        print(f'logp Gaussian Mixture with k = {k} = {logp_gm} ')
        # BUG FIX: the baseline here is a FULL-covariance Gaussian; the old
        # label said "Gaussian Diagonal".
        print(f'logp single Gaussian = {logp_g} ')
        print(f'At k = {k} , rel_diff for logps = {rel_diff_logps} ')
        print('############## ')
    except Exception as e:
        # EM can fail to fit (e.g. too many components for the data);
        # report and continue with the next k.
        print(f'Caught exception {e} ')
For Dateset "kitti" Calculating relative difference of log likelihood Latent_dim = 128, Gaussiam Mixture covariance type = full and reg_covar = 0.1 ############################ logp Gaussin Mixture with k = 10 = -1.327926347027144 logp Gaussian Diagonal = -11.438871900903225 At k = 10 , rel_diff for logps = 0.8839110745770052 ############## logp Gaussin Mixture with k = 20 = -0.37018736953522263 logp Gaussian Diagonal = -11.438871900903225 At k = 20 , rel_diff for logps = 0.967637772960287 ############## logp Gaussin Mixture with k = 50 = -6.77506646715561 logp Gaussian Diagonal = -11.438871900903225 At k = 50 , rel_diff for logps = 0.40771550500354464 ############## logp Gaussin Mixture with k = 70 = -13.593873988321024 logp Gaussian Diagonal = -11.438871900903225 At k = 70 , rel_diff for logps = -0.18839288577465735 ############## logp Gaussin Mixture with k = 80 = -16.753749716347343 logp Gaussian Diagonal = -11.438871900903225 At k = 80 , rel_diff for logps = -0.4646330391220178 ############## logp Gaussin Mixture with k = 100 = -28.776018292355722 logp Gaussian Diagonal = -11.438871900903225 At k = 100 , rel_diff for logps = -1.5156342812164494 ############## logp Gaussin Mixture with k = 200 = -65.05804713056555 logp Gaussian Diagonal = -11.438871900903225 At k = 200 , rel_diff for logps = -4.687453071786607 ##############